{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Obtain the SELU parameters for arbitrary fixed points\n",
    "\n",
    "*Author:* Guenter Klambauer, 2017\n",
    "\n",
    "tested under Python 3.5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.special import erf,erfc\n",
    "from sympy import Symbol, solve, nsolve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Function to obtain the parameters for the SELU with arbitrary fixed point (mean variance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def getSeluParameters(fixedpointMean=0,fixedpointVar=1):\n",
    "    \"\"\" Finding the parameters of the SELU activation function. The function returns alpha and lambda for the desired fixed point. \"\"\"\n",
    "    \n",
    "    import sympy\n",
    "    from sympy import Symbol, solve, nsolve\n",
    "\n",
    "    aa = Symbol('aa')\n",
    "    ll = Symbol('ll')\n",
    "    nu = fixedpointMean \n",
    "    tau = fixedpointVar \n",
    "\n",
    "    mean =  0.5*ll*(nu + np.exp(-nu**2/(2*tau))*np.sqrt(2/np.pi)*np.sqrt(tau) + \\\n",
    "                        nu*erf(nu/(np.sqrt(2*tau))) - aa*erfc(nu/(np.sqrt(2*tau))) + \\\n",
    "                        np.exp(nu+tau/2)*aa*erfc((nu+tau)/(np.sqrt(2*tau))))\n",
    "\n",
    "    var = 0.5*ll**2*(np.exp(-nu**2/(2*tau))*np.sqrt(2/np.pi*tau)*nu + (nu**2+tau)* \\\n",
    "                          (1+erf(nu/(np.sqrt(2*tau)))) + aa**2 *erfc(nu/(np.sqrt(2*tau))) \\\n",
    "                          - aa**2 * 2 *np.exp(nu+tau/2)*erfc((nu+tau)/(np.sqrt(2*tau)))+ \\\n",
    "                          aa**2*np.exp(2*(nu+tau))*erfc((nu+2*tau)/(np.sqrt(2*tau))) ) - mean**2\n",
    "\n",
    "    eq1 = mean - nu\n",
    "    eq2 = var - tau\n",
    "\n",
    "    res = nsolve( (eq2, eq1), (aa,ll), (1.67,1.05))\n",
    "    return float(res[0]),float(res[1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.6732632423543774, 1.0507009873554802)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "### To recover the parameters of the SELU with mean zero and unit variance\n",
    "getSeluParameters(0,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1.9769021954241999, 1.073851239616047)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "### To obtain new parameters for mean zero and variance 2\n",
    "myFixedPointMean = -0.1\n",
    "myFixedPointVar = 2.0\n",
    "myAlpha, myLambda = getSeluParameters(myFixedPointMean,myFixedPointVar)\n",
    "getSeluParameters(myFixedPointMean,myFixedPointVar)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Adjust the SELU function and Dropout to your new parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def selu(x):\n",
    "    with ops.name_scope('elu') as scope:\n",
    "        alpha = myAlpha\n",
    "        scale = myLambda\n",
    "        return scale*tf.where(x>=0.0, x, alpha*tf.nn.elu(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def dropout_selu(x, rate, alpha= -myAlpha*myLambda, fixedPointMean=myFixedPointMean, fixedPointVar=myFixedPointVar, \n",
    "                 noise_shape=None, seed=None, name=None, training=False):\n",
    "    \"\"\"Dropout to a value with rescaling.\"\"\"\n",
    "\n",
    "    def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name):\n",
    "        keep_prob = 1.0 - rate\n",
    "        x = ops.convert_to_tensor(x, name=\"x\")\n",
    "        if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1:\n",
    "            raise ValueError(\"keep_prob must be a scalar tensor or a float in the \"\n",
    "                                             \"range (0, 1], got %g\" % keep_prob)\n",
    "        keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name=\"keep_prob\")\n",
    "        keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())\n",
    "\n",
    "        alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name=\"alpha\")\n",
    "        keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())\n",
    "\n",
    "        if tensor_util.constant_value(keep_prob) == 1:\n",
    "            return x\n",
    "\n",
    "        noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x)\n",
    "        random_tensor = keep_prob\n",
    "        random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype)\n",
    "        binary_tensor = math_ops.floor(random_tensor)\n",
    "        ret = x * binary_tensor + alpha * (1-binary_tensor)\n",
    "\n",
    "\n",
    "        a = tf.sqrt(fixedPointVar / (keep_prob *((1-keep_prob) * tf.pow(alpha-fixedPointMean,2) + fixedPointVar)))\n",
    "        \n",
    "        b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha)\n",
    "        ret = a * ret + b\n",
    "        ret.set_shape(x.get_shape())\n",
    "        return ret\n",
    "\n",
    "    with ops.name_scope(name, \"dropout\", [x]) as name:\n",
    "        return utils.smart_cond(training,\n",
    "            lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name),\n",
    "            lambda: array_ops.identity(x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean/var should be at: -0.1 / 2.0\n",
      "Input data mean/var:   -0.081669516861 / 1.985695838928\n",
      "After selu:            -0.086086399853 / 2.011691331863\n",
      "After dropout mean/var -0.080544397235 / 2.031569242477\n"
     ]
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "from __future__ import absolute_import, division, print_function\n",
    "import numbers\n",
    "from tensorflow.contrib import layers\n",
    "from tensorflow.python.framework import ops\n",
    "from tensorflow.python.framework import tensor_shape\n",
    "from tensorflow.python.framework import tensor_util\n",
    "from tensorflow.python.ops import math_ops\n",
    "from tensorflow.python.ops import random_ops\n",
    "from tensorflow.python.ops import array_ops\n",
    "from tensorflow.python.layers import utils\n",
    "\n",
    "\n",
    "x = tf.Variable(tf.random_normal([10000],mean=myFixedPointMean, stddev=np.sqrt(myFixedPointVar)))\n",
    "w = selu(x)\n",
    "y = dropout_selu(w,0.2,training=True)\n",
    "init = tf.global_variables_initializer()\n",
    "                \n",
    "gpu_options = tf.GPUOptions(allow_growth=True)\n",
    "with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:\n",
    "    sess.run(init)\n",
    "    z,zz, zzz = sess.run([x, w, y]) \n",
    "    #print(z)\n",
    "    #print(zz)\n",
    "    print(\"mean/var should be at:\", myFixedPointMean, \"/\", myFixedPointVar)\n",
    "    print(\"Input data mean/var:  \", \"{:.12f}\".format(np.mean(z)), \"/\", \"{:.12f}\".format(np.var(z)))    \n",
    "    print(\"After selu:           \", \"{:.12f}\".format(np.mean(zz)), \"/\", \"{:.12f}\".format(np.var(zz)))\n",
    "    print(\"After dropout mean/var\", \"{:.12f}\".format(np.mean(zzz)), \"/\", \"{:.12f}\".format(np.var(zzz)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### For completeness: These are the correct expressions for mean zero and unit variance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "myAlpha = -np.sqrt(2/np.pi) / (np.exp(0.5) * erfc(1/np.sqrt(2))-1 )  \n",
    "myLambda = (1-np.sqrt(np.exp(1))*erfc(1/np.sqrt(2)))  *  \\\n",
    "            np.sqrt( 2*np.pi/ (2 + np.pi -2*np.sqrt(np.exp(1))*(2+np.pi)*erfc(1/np.sqrt(2)) + \\\n",
    "            np.exp(1)*np.pi*erfc(1/np.sqrt(2))**2 + 2*np.exp(2)*erfc(np.sqrt(2))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alpha parameter of the SELU:  1.67326324235\n",
      "Lambda parameter of the SELU:  1.05070098736\n"
     ]
    }
   ],
   "source": [
    "print(\"Alpha parameter of the SELU: \", myAlpha)\n",
    "print(\"Lambda parameter of the SELU: \", myLambda)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf-alpha",
   "language": "python",
   "name": "tf-alpha"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
